package com.rtsapp.server.network.protocol.crypto;


import io.netty.buffer.ByteBuf;

/**
 * Stateful RC4 stream cipher that encrypts/decrypts a region of a Netty
 * {@link ByteBuf} in place.
 *
 * <p>RC4 is symmetric: applying the same keystream twice restores the original
 * bytes, so the same methods serve for both encryption and decryption.
 *
 * <p>An instance holds mutable cipher state (the permutation table, the stream
 * indices and a scratch buffer) and performs no synchronization, so a single
 * instance must not be used from several threads at once without external
 * locking.
 */
public class RC4 {

    /** Size of the RC4 permutation table. */
    private static final int N = 256;

    /** Maximum number of bytes copied out of the buffer and processed per chunk by {@link #crypt0}. */
    private static final int DEFAULT_CRYPT_SIZE = 256;

    /** Permutation table; initialised by the key schedule and mutated as keystream is produced. */
    private final short[] s = new short[N];

    /** Keystream indices used by {@link #crypt0}; they persist across calls. */
    private int i = 0, j = 0;

    /** Reusable scratch array so that {@link #crypt0} does not allocate on every call. */
    private final byte[] bytesForCrypt = new byte[DEFAULT_CRYPT_SIZE];

    /**
     * Creates a cipher and runs the RC4 key schedule over {@code key}.
     *
     * @param key key material; it is repeated cyclically to fill the 256 schedule slots
     * @throws NullPointerException     if {@code key} is {@code null}
     * @throws IllegalArgumentException if {@code key} is empty
     */
    public RC4(byte[] key) {
        if (key == null) {
            throw new NullPointerException("key must not be null");
        }
        if (key.length == 0) {
            throw new IllegalArgumentException("key must not be empty");
        }

        for (int idx = 0; idx < N; idx++) {
            s[idx] = (short) idx;
        }

        int mix = 0;
        for (int idx = 0; idx < N; idx++) {
            // Java bytes are signed: mask to 0..255, otherwise a key byte >= 0x80
            // can drive the sum negative, '%' then yields a negative index and the
            // table access below throws ArrayIndexOutOfBoundsException.
            int keyByte = key[idx % key.length] & 0xFF;
            mix = (mix + s[idx] + keyByte) % N;

            short swap = s[idx];
            s[idx] = s[mix];
            s[mix] = swap;
        }
    }

    /**
     * Encrypts/decrypts the bytes of {@code buffer} in {@code [begin, end)} in place,
     * working in chunks of at most {@value #DEFAULT_CRYPT_SIZE} bytes through an
     * internal scratch array. The buffer's reader and writer indices are not touched
     * because only absolute-index accessors are used.
     *
     * <p>The original author's measurements found this method 2x to 20x faster than
     * {@link #crypt1}; prefer this method for buffer encryption.
     *
     * @param buffer buffer whose content is transformed in place
     * @param begin  first index to process (inclusive, zero-based)
     * @param end    index at which processing stops (exclusive)
     */
    public void crypt0(ByteBuf buffer, int begin, int end) {
        while (end > begin) {
            int length = Math.min(end - begin, DEFAULT_CRYPT_SIZE);

            buffer.getBytes(begin, bytesForCrypt, 0, length);
            internalCrypt(bytesForCrypt, length);
            buffer.setBytes(begin, bytesForCrypt, 0, length);

            begin += length;
        }
    }

    /**
     * Encrypts/decrypts the bytes of {@code buffer} in {@code [begin, end)} in place,
     * one byte at a time.
     *
     * <p>NOTE(review): unlike {@link #crypt0}, this method restarts its stream indices
     * at zero on every call while still mutating the shared permutation table. Its
     * keystream therefore matches {@link #crypt0} only for the first call on a fresh
     * instance, and the two methods should not be mixed on one instance. The behavior
     * is kept as-is for compatibility with existing peers - confirm before changing.
     *
     * @param buffer buffer whose content is transformed in place
     * @param begin  first index to process (inclusive, zero-based)
     * @param end    index at which processing stops (exclusive)
     */
    public void crypt1(ByteBuf buffer, int begin, int end) {
        // Deliberately local: these shadow the instance stream indices (see note above).
        int i = 0;
        int j = 0;
        for (int k = begin; k < end; k++) {
            i = (i + 1) % N;
            j = (j + s[i]) % N;

            // swap s[i] and s[j]
            short tmp = s[i];
            s[i] = s[j];
            s[j] = tmp;

            int t = (s[i] + s[j]) % N;
            buffer.setByte(k, (byte) (buffer.getByte(k) ^ s[t]));
        }
    }

    /**
     * XORs the first {@code length} bytes of {@code buffer} with the next
     * {@code length} keystream bytes, advancing the instance stream state.
     *
     * @param buffer array transformed in place
     * @param length number of leading bytes to process
     */
    private void internalCrypt(byte[] buffer, int length) {
        for (int k = 0; k < length; k++) {
            i = (i + 1) % N;
            j = (j + s[i]) % N;

            // swap s[i] and s[j]
            short tmp = s[i];
            s[i] = s[j];
            s[j] = tmp;

            int t = (s[i] + s[j]) % N;
            buffer[k] ^= s[t];
        }
    }

}
